{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "InterFaceGAN",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qJDJLE3v0HNr"
      },
      "source": [
        "# Fetch Codebase and Models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JqiWKjpFa0ov"
      },
      "source": [
        "import os\n",
        "os.chdir('/content')\n",
        "CODE_DIR = 'interfacegan'\n",
        "!git clone https://github.com/genforce/interfacegan.git $CODE_DIR\n",
        "os.chdir(f'./{CODE_DIR}')\n",
        "!wget https://www.dropbox.com/s/t74z87pk3cf8ny7/pggan_celebahq.pth?dl=1 -O models/pretrain/pggan_celebahq.pth --quiet\n",
        "!wget https://www.dropbox.com/s/nmo2g3u0qt7x70m/stylegan_celebahq.pth?dl=1 -O models/pretrain/stylegan_celebahq.pth --quiet\n",
        "!wget https://www.dropbox.com/s/qyv37eaobnow7fu/stylegan_ffhq.pth?dl=1 -O models/pretrain/stylegan_ffhq.pth --quiet"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQ_IXBZr8YcJ"
      },
      "source": [
        "# Define Utility Functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ijKTlG5GeTd3"
      },
      "source": [
        "import os.path\n",
        "import io\n",
        "import IPython.display\n",
        "import numpy as np\n",
        "import cv2\n",
        "import PIL.Image\n",
        "\n",
        "import torch\n",
        "\n",
        "from models.model_settings import MODEL_POOL\n",
        "from models.pggan_generator import PGGANGenerator\n",
        "from models.stylegan_generator import StyleGANGenerator\n",
        "from utils.manipulator import linear_interpolate\n",
        "\n",
        "\n",
        "def build_generator(model_name):\n",
        "  \"\"\"Builds the generator by model name.\"\"\"\n",
        "  gan_type = MODEL_POOL[model_name]['gan_type']\n",
        "  if gan_type == 'pggan':\n",
        "    generator = PGGANGenerator(model_name)\n",
        "  elif gan_type == 'stylegan':\n",
        "    generator = StyleGANGenerator(model_name)\n",
        "  return generator\n",
        "\n",
        "\n",
        "def sample_codes(generator, num, latent_space_type='Z', seed=0):\n",
        "  \"\"\"Samples latent codes randomly.\"\"\"\n",
        "  np.random.seed(seed)\n",
        "  codes = generator.easy_sample(num)\n",
        "  if generator.gan_type == 'stylegan' and latent_space_type == 'W':\n",
        "    codes = torch.from_numpy(codes).type(torch.FloatTensor).to(generator.run_device)\n",
        "    codes = generator.get_value(generator.model.mapping(codes))\n",
        "  return codes\n",
        "\n",
        "\n",
        "def imshow(images, col, viz_size=256):\n",
        "  \"\"\"Shows images in one figure.\"\"\"\n",
        "  num, height, width, channels = images.shape\n",
        "  assert num % col == 0\n",
        "  row = num // col\n",
        "\n",
        "  fused_image = np.zeros((viz_size * row, viz_size * col, channels), dtype=np.uint8)\n",
        "\n",
        "  for idx, image in enumerate(images):\n",
        "    i, j = divmod(idx, col)\n",
        "    y = i * viz_size\n",
        "    x = j * viz_size\n",
        "    if height != viz_size or width != viz_size:\n",
        "      image = cv2.resize(image, (viz_size, viz_size))\n",
        "    fused_image[y:y + viz_size, x:x + viz_size] = image\n",
        "\n",
        "  fused_image = np.asarray(fused_image, dtype=np.uint8)\n",
        "  data = io.BytesIO()\n",
        "  PIL.Image.fromarray(fused_image).save(data, 'jpeg')\n",
        "  im_data = data.getvalue()\n",
        "  disp = IPython.display.display(IPython.display.Image(im_data))\n",
        "  return disp"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q7gkmrVW8eR1"
      },
      "source": [
        "# Select a Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NoWI4fPQ6Gnf"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "model_name = \"stylegan_ffhq\" #@param ['pggan_celebahq','stylegan_celebahq', 'stylegan_ffhq']\n",
        "latent_space_type = \"W\" #@param ['Z', 'W']\n",
        "\n",
        "generator = build_generator(model_name)\n",
        "\n",
        "ATTRS = ['age', 'eyeglasses', 'gender', 'pose', 'smile']\n",
        "boundaries = {}\n",
        "for i, attr_name in enumerate(ATTRS):\n",
        "  boundary_name = f'{model_name}_{attr_name}'\n",
        "  if generator.gan_type == 'stylegan' and latent_space_type == 'W':\n",
        "    boundaries[attr_name] = np.load(f'boundaries/{boundary_name}_w_boundary.npy')\n",
        "  else:\n",
        "    boundaries[attr_name] = np.load(f'boundaries/{boundary_name}_boundary.npy')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zDStH1O5t1KC"
      },
      "source": [
        "# Sample latent codes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qlRGKZbJt9hA"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "num_samples = 4 #@param {type:\"slider\", min:1, max:8, step:1}\n",
        "noise_seed = 0 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
        "\n",
        "latent_codes = sample_codes(generator, num_samples, latent_space_type, noise_seed)\n",
        "if generator.gan_type == 'stylegan' and latent_space_type == 'W':\n",
        "  synthesis_kwargs = {'latent_space_type': 'W'}\n",
        "else:\n",
        "  synthesis_kwargs = {}\n",
        "\n",
        "images = generator.easy_synthesize(latent_codes, **synthesis_kwargs)['image']\n",
        "imshow(images, col=num_samples)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MmRPN3xz8jCH"
      },
      "source": [
        "# Edit facial attributes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ccONBF60mVir"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "age = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "eyeglasses = 0 #@param {type:\"slider\", min:-2.9, max:3.0, step:0.1}\n",
        "gender = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "pose = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "smile = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "\n",
        "new_codes = latent_codes.copy()\n",
        "for i, attr_name in enumerate(ATTRS):\n",
        "  new_codes += boundaries[attr_name] * eval(attr_name)\n",
        "\n",
        "new_images = generator.easy_synthesize(new_codes, **synthesis_kwargs)['image']\n",
        "imshow(new_images, col=num_samples)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}